# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for locally-connected layers."""


import os

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

import keras
from keras.layers.locally_connected import locally_connected_utils
from keras.optimizers.optimizer_v2 import rmsprop
from keras.testing_infra import test_combinations
from keras.testing_infra import test_utils

# isort: off
from tensorflow.python.framework import (
    test_util as tf_test_util,
)
from tensorflow.python.training.rmsprop import (
    RMSPropOptimizer,
)

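# Cross-product of data_format, padding, and implementation mode used to
# parameterize the LocallyConnected layer tests below. The combination
# padding="same" with implementation=1 is expected to raise a ValueError and
# is asserted as such in the tests.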
_DATA_FORMAT_PADDING_IMPLEMENTATION = [
    {"data_format": "channels_first", "padding": "valid", "implementation": 1},
    {"data_format": "channels_first", "padding": "same", "implementation": 1},
    {"data_format": "channels_last", "padding": "valid", "implementation": 1},
    {"data_format": "channels_last", "padding": "same", "implementation": 1},
    {"data_format": "channels_first", "padding": "valid", "implementation": 2},
    {"data_format": "channels_first", "padding": "same", "implementation": 2},
    {"data_format": "channels_last", "padding": "valid", "implementation": 2},
    {"data_format": "channels_last", "padding": "same", "implementation": 2},
    {"data_format": "channels_first", "padding": "valid", "implementation": 3},
    {"data_format": "channels_first", "padding": "same", "implementation": 3},
    {"data_format": "channels_last", "padding": "valid", "implementation": 3},
    {"data_format": "channels_last", "padding": "same", "implementation": 3},
]


@test_combinations.generate(test_combinations.combine(mode=["graph", "eager"]))
class LocallyConnected1DLayersTest(tf.test.TestCase, parameterized.TestCase):
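    """Tests for keras.layers.LocallyConnected1D."""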
    @parameterized.parameters(_DATA_FORMAT_PADDING_IMPLEMENTATION)
    def test_locallyconnected_1d(self, data_format, padding, implementation):
        with self.cached_session():
            num_samples = 2
            num_steps = 8
            input_dim = 5
            filter_length = 3
            filters = 4

            for strides in [1]:
                if padding == "same" and strides != 1:
                    continue
                kwargs = {
                    "filters": filters,
                    "kernel_size": filter_length,
                    "padding": padding,
                    "strides": strides,
                    "data_format": data_format,
                    "implementation": implementation,
                }

                if padding == "same" and implementation == 1:
                    self.assertRaises(
                        ValueError, keras.layers.LocallyConnected1D, **kwargs
                    )
                else:
                    test_utils.layer_test(
                        keras.layers.LocallyConnected1D,
                        kwargs=kwargs,
                        input_shape=(num_samples, num_steps, input_dim),
                    )

    @parameterized.parameters(_DATA_FORMAT_PADDING_IMPLEMENTATION)
    def test_locallyconnected_1d_regularization(
        self, data_format, padding, implementation
    ):
        num_samples = 2
        num_steps = 8
        input_dim = 5
        filter_length = 3
        filters = 4
        kwargs = {
            "filters": filters,
            "kernel_size": filter_length,
            "kernel_regularizer": "l2",
            "bias_regularizer": "l2",
            "activity_regularizer": "l2",
            "data_format": data_format,
            "implementation": implementation,
            "padding": padding,
        }

        if padding == "same" and implementation == 1:
            self.assertRaises(
                ValueError, keras.layers.LocallyConnected1D, **kwargs
            )
        else:
            with self.cached_session():
                layer = keras.layers.LocallyConnected1D(**kwargs)
                layer.build((num_samples, num_steps, input_dim))
                self.assertLen(layer.losses, 2)
                layer(
                    keras.backend.variable(
                        np.ones((num_samples, num_steps, input_dim))
                    )
                )
                self.assertLen(layer.losses, 3)

            k_constraint = keras.constraints.max_norm(0.01)
            b_constraint = keras.constraints.max_norm(0.01)
            kwargs = {
                "filters": filters,
                "kernel_size": filter_length,
                "kernel_constraint": k_constraint,
                "bias_constraint": b_constraint,
            }
            with self.cached_session():
                layer = keras.layers.LocallyConnected1D(**kwargs)
                layer.build((num_samples, num_steps, input_dim))
                self.assertEqual(layer.kernel.constraint, k_constraint)
                self.assertEqual(layer.bias.constraint, b_constraint)

    def test_locallyconnected1d_invalid_output_shapes(self):
        kwargs = {"filters": 2, "kernel_size": 10}
        with self.assertRaisesRegex(
            ValueError, r"""One of the dimensions in the output is <= 0 """
        ):
            layer = keras.layers.LocallyConnected1D(**kwargs)
            layer.build((None, 5, 2))


@test_combinations.generate(test_combinations.combine(mode=["graph", "eager"]))
class LocallyConnected2DLayersTest(tf.test.TestCase, parameterized.TestCase):
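    """Tests for keras.layers.LocallyConnected2D."""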
    @parameterized.parameters(_DATA_FORMAT_PADDING_IMPLEMENTATION)
    def test_locallyconnected_2d(self, data_format, padding, implementation):
        with self.cached_session():
            num_samples = 8
            filters = 3
            stack_size = 4
            num_row = 6
            num_col = 10

            for strides in [(1, 1), (2, 2)]:
                if padding == "same" and strides != (1, 1):
                    continue

                kwargs = {
                    "filters": filters,
                    "kernel_size": 3,
                    "padding": padding,
                    "kernel_regularizer": "l2",
                    "bias_regularizer": "l2",
                    "strides": strides,
                    "data_format": data_format,
                    "implementation": implementation,
                }

                if padding == "same" and implementation == 1:
                    self.assertRaises(
                        ValueError, keras.layers.LocallyConnected2D, **kwargs
                    )
                else:
                    test_utils.layer_test(
                        keras.layers.LocallyConnected2D,
                        kwargs=kwargs,
                        input_shape=(num_samples, num_row, num_col, stack_size),
                    )

    @parameterized.parameters(_DATA_FORMAT_PADDING_IMPLEMENTATION)
    def test_locallyconnected_2d_channels_first(
        self, data_format, padding, implementation
    ):
        with self.cached_session():
            num_samples = 8
            filters = 3
            stack_size = 4
            num_row = 6
            num_col = 10
            kwargs = {
                "filters": filters,
                "kernel_size": 3,
                "data_format": data_format,
                "implementation": implementation,
                "padding": padding,
            }

            if padding == "same" and implementation == 1:
                self.assertRaises(
                    ValueError, keras.layers.LocallyConnected2D, **kwargs
                )
            else:
                test_utils.layer_test(
                    keras.layers.LocallyConnected2D,
                    kwargs=kwargs,
                    input_shape=(num_samples, num_row, num_col, stack_size),
                )

    @parameterized.parameters(_DATA_FORMAT_PADDING_IMPLEMENTATION)
    def test_locallyconnected_2d_regularization(
        self, data_format, padding, implementation
    ):
        num_samples = 2
        filters = 3
        stack_size = 4
        num_row = 6
        num_col = 7
        kwargs = {
            "filters": filters,
            "kernel_size": 3,
            "kernel_regularizer": "l2",
            "bias_regularizer": "l2",
            "activity_regularizer": "l2",
            "implementation": implementation,
            "padding": padding,
            "data_format": data_format,
        }

        if padding == "same" and implementation == 1:
            self.assertRaises(
                ValueError, keras.layers.LocallyConnected2D, **kwargs
            )
        else:
            with self.cached_session():
                layer = keras.layers.LocallyConnected2D(**kwargs)
                layer.build((num_samples, num_row, num_col, stack_size))
                self.assertLen(layer.losses, 2)
                layer(
                    keras.backend.variable(
                        np.ones((num_samples, num_row, num_col, stack_size))
                    )
                )
                self.assertLen(layer.losses, 3)

            k_constraint = keras.constraints.max_norm(0.01)
            b_constraint = keras.constraints.max_norm(0.01)
            kwargs = {
                "filters": filters,
                "kernel_size": 3,
                "kernel_constraint": k_constraint,
                "bias_constraint": b_constraint,
            }
            with self.cached_session():
                layer = keras.layers.LocallyConnected2D(**kwargs)
                layer.build((num_samples, num_row, num_col, stack_size))
                self.assertEqual(layer.kernel.constraint, k_constraint)
                self.assertEqual(layer.bias.constraint, b_constraint)

    def test_locallyconnected2d_invalid_output_shapes(self):
        kwargs = {"filters": 2, "kernel_size": 10}
        with self.assertRaisesRegex(
            ValueError, r"""One of the dimensions in the output is <= 0 """
        ):
            layer = keras.layers.LocallyConnected2D(**kwargs)
            layer.build((None, 5, 5, 2))


@test_combinations.generate(test_combinations.combine(mode=["graph", "eager"]))
class LocallyConnectedImplementationModeTest(
    tf.test.TestCase, parameterized.TestCase
):
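    """Consistency, saving, and utility tests across the LocallyConnected
    implementation modes.
    """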
    @parameterized.parameters(
        [
            {"width": 1, "data_format": "channels_first"},
            {"width": 1, "data_format": "channels_last"},
            {"width": 6, "data_format": "channels_first"},
            {"width": 6, "data_format": "channels_last"},
        ]
    )
    def test_locallyconnected_implementation(self, width, data_format):
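        """Implementations 1, 2 and 3 should produce matching outputs (within
        tolerance) once their weights are copied from the implementation=2
        model, both at initialization and after a few training steps.
        """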
        with self.cached_session():
            num_samples = 4
            num_classes = 3
            num_epochs = 2

            np.random.seed(1)
            tf_test_util.random_seed.set_seed(1)
            # The following code generates sparse integer targets and
            # converts them to one-hot encoded vectors.
            # Create sparse targets, e.g. [0, 1, 2].
            sparse_targets = np.random.randint(0, num_classes, (num_samples,))

            # Convert to one-hot encoding. Final targets look like:
            # [[ 1. 0. 0. ]
            #  [ 0. 1. 0. ]
            #  [ 0. 0. 1. ]]

            targets = np.zeros((sparse_targets.size, num_classes))
            targets[np.arange(sparse_targets.size), sparse_targets] = 1
            height = 7
            filters = 2
            inputs = get_inputs(
                data_format, filters, height, num_samples, width
            )

            kernel_x = (3,)
            kernel_y = () if width == 1 else (2,)
            stride_x = (1,)
            stride_y = () if width == 1 else (3,)
            layers = 2

            kwargs = {
                "layers": layers,
                "filters": filters,
                "kernel_size": kernel_x + kernel_y,
                "strides": stride_x + stride_y,
                "data_format": data_format,
                "num_classes": num_classes,
            }

            model_1 = get_model(implementation=1, **kwargs)
            model_2 = get_model(implementation=2, **kwargs)
            model_3 = get_model(implementation=3, **kwargs)

            # Build models.
            model_1.train_on_batch(inputs, targets)
            model_2.train_on_batch(inputs, targets)
            model_3.train_on_batch(inputs, targets)

            # Copy weights.
            copy_model_weights(model_from=model_2, model_to=model_1)
            copy_model_weights(model_from=model_2, model_to=model_3)

            # Compare outputs at initialization.
            out_1 = model_1(inputs)
            out_2 = model_2(inputs)
            out_3 = model_3(inputs)

            self.assertAllCloseAccordingToType(
                out_2, out_1, rtol=1e-5, atol=1e-5
            )
            self.assertAllCloseAccordingToType(
                out_2, out_3, rtol=1e-5, atol=1e-5
            )
            self.assertAllCloseAccordingToType(
                out_1, out_3, rtol=1e-5, atol=1e-5
            )

            # Train.
            model_1.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )
            model_2.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )
            model_3.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )

            # Compare outputs after a few training steps.
            out_1 = model_1(inputs)
            out_2 = model_2(inputs)
            out_3 = model_3(inputs)

            self.assertAllCloseAccordingToType(out_2, out_1, atol=2e-4)
            self.assertAllCloseAccordingToType(out_2, out_3, atol=2e-4)
            self.assertAllCloseAccordingToType(out_1, out_3, atol=2e-4)

    @parameterized.parameters(
        [
            {"width": 1, "data_format": "channels_first"},
            {"width": 1, "data_format": "channels_last"},
            {"width": 6, "data_format": "channels_first"},
            {"width": 6, "data_format": "channels_last"},
        ]
    )
    def test_locallyconnected_save(self, width, data_format):
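        """Model outputs should be unchanged (within tolerance) after a
        save/load round trip, for every implementation mode.
        """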
        with self.cached_session():
            num_samples = 4
            num_classes = 3
            num_epochs = 2

            np.random.seed(1)
            tf_test_util.random_seed.set_seed(1)
            # The following code generates sparse integer targets and
            # converts them to one-hot encoded vectors.
            # Create sparse targets, e.g. [0, 1, 2].
            sparse_targets = np.random.randint(0, num_classes, (num_samples,))

            # Convert to one-hot encoding. Final targets look like:
            # [[ 1. 0. 0. ]
            #  [ 0. 1. 0. ]
            #  [ 0. 0. 1. ]]

            targets = np.zeros((sparse_targets.size, num_classes))
            targets[np.arange(sparse_targets.size), sparse_targets] = 1

            height = 7
            filters = 2
            inputs = get_inputs(
                data_format, filters, height, num_samples, width
            )

            kernel_x = (3,)
            kernel_y = () if width == 1 else (2,)
            stride_x = (1,)
            stride_y = () if width == 1 else (3,)
            layers = 2

            kwargs = {
                "layers": layers,
                "filters": filters,
                "kernel_size": kernel_x + kernel_y,
                "strides": stride_x + stride_y,
                "data_format": data_format,
                "num_classes": num_classes,
            }

            model_1 = get_model_saveable(implementation=1, **kwargs)
            model_2 = get_model_saveable(implementation=2, **kwargs)
            model_3 = get_model_saveable(implementation=3, **kwargs)

            # Train.
            model_1.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )
            model_2.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )
            model_3.fit(
                x=inputs,
                y=targets,
                epochs=num_epochs,
                batch_size=num_samples,
                shuffle=False,
            )

            out_1_before = model_1(inputs)
            out_2_before = model_2(inputs)
            out_3_before = model_3(inputs)

            path_1 = os.path.join(self.get_temp_dir(), "model_1_path")
            model_1.save(path_1)
            model_1 = keras.models.load_model(
                path_1, custom_objects={"xent": xent}
            )
            path_2 = os.path.join(self.get_temp_dir(), "model_2_path")
            model_2.save(path_2)
            model_2 = keras.models.load_model(
                path_2, custom_objects={"xent": xent}
            )
            path_3 = os.path.join(self.get_temp_dir(), "model_3_path")
            model_3.save(path_3)
            model_3 = keras.models.load_model(
                path_3, custom_objects={"xent": xent}
            )

            out_1_after = model_1(inputs)
            out_2_after = model_2(inputs)
            out_3_after = model_3(inputs)

            self.assertAllCloseAccordingToType(
                out_1_before, out_1_after, atol=2e-4
            )
            self.assertAllCloseAccordingToType(
                out_2_before, out_2_after, atol=2e-4
            )
            self.assertAllCloseAccordingToType(
                out_3_before, out_3_after, atol=2e-4
            )

    def test_make_2d(self):
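        """locally_connected_utils.make_2d should agree with a plain numpy
        reshape that collapses the dimensions before and after split_dim.
        """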
        input_shapes = [
            (0,),
            (0, 0),
            (1,),
            (2,),
            (3,),
            (1, 0),
            (0, 3),
            (1, 1),
            (1, 2),
            (3, 1),
            (2, 2),
            (3, 3),
            (1, 0, 1),
            (5, 2, 3),
            (3, 5, 6, 7, 0),
            (3, 2, 2, 4, 4),
            (1, 2, 3, 4, 7, 2),
        ]
        np.random.seed(1)

        for input_shape in input_shapes:
            inputs = np.random.normal(0, 1, input_shape)
            inputs_tf = keras.backend.variable(inputs)

            split_dim = np.random.randint(0, inputs.ndim + 1)
            shape_2d = (
                int(np.prod(inputs.shape[:split_dim])),
                int(np.prod(inputs.shape[split_dim:])),
            )
            inputs_2d = np.reshape(inputs, shape_2d)

            inputs_2d_tf = locally_connected_utils.make_2d(inputs_tf, split_dim)
            inputs_2d_tf = keras.backend.get_value(inputs_2d_tf)

            self.assertAllCloseAccordingToType(inputs_2d, inputs_2d_tf)


def get_inputs(data_format, filters, height, num_samples, width):
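    """Returns a random float32 input batch laid out for `data_format`.

    `width == 1` produces the 1D input shape, otherwise the 2D input shape.
    """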
    if data_format == "channels_first":
        if width == 1:
            input_shape = (filters, height)
        else:
            input_shape = (filters, height, width)

    elif data_format == "channels_last":
        if width == 1:
            input_shape = (height, filters)
        else:
            input_shape = (height, width, filters)

    else:
        raise NotImplementedError(data_format)

    inputs = np.random.normal(0, 1, (num_samples,) + input_shape).astype(
        np.float32
    )
    return inputs


def xent(y_true, y_pred):
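    """Sparse softmax cross-entropy on flattened integer labels.

    Passed via `custom_objects` when reloading the saved models in
    `test_locallyconnected_save`.
    """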
    y_true = keras.backend.cast(keras.backend.reshape(y_true, (-1,)), tf.int32)

    return tf.compat.v1.nn.sparse_softmax_cross_entropy_with_logits(
        labels=y_true, logits=y_pred
    )


def get_model(
    implementation,
    filters,
    kernel_size,
    strides,
    layers,
    num_classes,
    data_format,
):
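    """Builds a stack of LocallyConnected layers followed by a Dense head.

    The model is compiled with the legacy TF1 `RMSPropOptimizer`.
    """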
    model = keras.Sequential()

    if len(kernel_size) == 1:
        lc_layer = keras.layers.LocallyConnected1D
    elif len(kernel_size) == 2:
        lc_layer = keras.layers.LocallyConnected2D
    else:
        raise NotImplementedError(kernel_size)

    for _ in range(layers):
        model.add(
            lc_layer(
                padding="valid",
                kernel_initializer=keras.initializers.random_normal(),
                bias_initializer=keras.initializers.random_normal(),
                filters=filters,
                strides=strides,
                kernel_size=kernel_size,
                activation=keras.activations.relu,
                data_format=data_format,
                implementation=implementation,
            )
        )

    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(num_classes))
    model.compile(
        optimizer=RMSPropOptimizer(0.01),
        metrics=[keras.metrics.categorical_accuracy],
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    )
    return model


def get_model_saveable(
    implementation,
    filters,
    kernel_size,
    strides,
    layers,
    num_classes,
    data_format,
):
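    """Same as `get_model`, but compiled with the Keras RMSProp optimizer,
    presumably so the compiled model can be saved and reloaded in
    `test_locallyconnected_save`.
    """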
    model = keras.Sequential()

    if len(kernel_size) == 1:
        lc_layer = keras.layers.LocallyConnected1D
    elif len(kernel_size) == 2:
        lc_layer = keras.layers.LocallyConnected2D
    else:
        raise NotImplementedError(kernel_size)

    for _ in range(layers):
        model.add(
            lc_layer(
                padding="valid",
                kernel_initializer=keras.initializers.random_normal(),
                bias_initializer=keras.initializers.random_normal(),
                filters=filters,
                strides=strides,
                kernel_size=kernel_size,
                activation=keras.activations.relu,
                data_format=data_format,
                implementation=implementation,
            )
        )

    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(num_classes))
    model.compile(
        optimizer=rmsprop.RMSProp(learning_rate=0.01),
        metrics=[keras.metrics.categorical_accuracy],
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    )
    return model


def copy_lc_weights_2_to_1(lc_layer_2_from, lc_layer_1_to):
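    """Copies weights from an implementation=2 layer to an implementation=1
    layer.

    The masked dense kernel is permuted, its non-zero entries are extracted,
    and they are reshaped to match the kernel shape of the target layer.
    """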
    lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
    lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask

    data_format = lc_layer_2_from.data_format

    if data_format == "channels_first":
        if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
            permutation = (3, 0, 1, 2)
        elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
            permutation = (4, 5, 0, 1, 2, 3)
        else:
            raise NotImplementedError(lc_layer_2_from)

    elif data_format == "channels_last":
        if isinstance(lc_layer_2_from, keras.layers.LocallyConnected1D):
            permutation = (2, 0, 1, 3)
        elif isinstance(lc_layer_2_from, keras.layers.LocallyConnected2D):
            permutation = (3, 4, 0, 1, 2, 5)
        else:
            raise NotImplementedError(lc_layer_2_from)

    else:
        raise NotImplementedError(data_format)

    lc_2_kernel_masked = keras.backend.permute_dimensions(
        lc_2_kernel_masked, permutation
    )

    lc_2_kernel_mask = tf.not_equal(lc_2_kernel_masked, 0)
    lc_2_kernel_flat = tf.compat.v1.boolean_mask(
        lc_2_kernel_masked, lc_2_kernel_mask
    )
    lc_2_kernel_reshaped = keras.backend.reshape(
        lc_2_kernel_flat, lc_layer_1_to.kernel.shape
    )

    lc_2_kernel_reshaped = keras.backend.get_value(lc_2_kernel_reshaped)
    lc_2_bias = keras.backend.get_value(lc_2_bias)

    lc_layer_1_to.set_weights([lc_2_kernel_reshaped, lc_2_bias])


def copy_lc_weights_2_to_3(lc_layer_2_from, lc_layer_3_to):
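    """Copies weights from an implementation=2 layer to an implementation=3
    layer.

    The masked dense kernel is flattened to 2D, transposed, and its non-zero
    entries become the flat kernel weights of the target layer.
    """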
    lc_2_kernel, lc_2_bias = lc_layer_2_from.weights
    lc_2_kernel_masked = lc_2_kernel * lc_layer_2_from.kernel_mask

    lc_2_kernel_masked = locally_connected_utils.make_2d(
        lc_2_kernel_masked,
        split_dim=keras.backend.ndim(lc_2_kernel_masked) // 2,
    )
    lc_2_kernel_masked = keras.backend.transpose(lc_2_kernel_masked)
    lc_2_kernel_mask = tf.not_equal(lc_2_kernel_masked, 0)
    lc_2_kernel_flat = tf.compat.v1.boolean_mask(
        lc_2_kernel_masked, lc_2_kernel_mask
    )

    lc_2_kernel_flat = keras.backend.get_value(lc_2_kernel_flat)
    lc_2_bias = keras.backend.get_value(lc_2_bias)

    lc_layer_3_to.set_weights([lc_2_kernel_flat, lc_2_bias])


def copy_model_weights(model_from, model_to):
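    """Copies weights layer-by-layer from `model_from` to `model_to`.

    LocallyConnected layers in `model_from` must use implementation=2; they
    are converted to the implementation used by the corresponding layer in
    `model_to`. Dense layers are copied as-is; other layers are skipped.
    """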
    for i in range(len(model_from.layers)):
        layer_from = model_from.layers[i]
        layer_to = model_to.layers[i]

        if isinstance(
            layer_from,
            (keras.layers.LocallyConnected2D, keras.layers.LocallyConnected1D),
        ) and isinstance(
            layer_to,
            (keras.layers.LocallyConnected2D, keras.layers.LocallyConnected1D),
        ):
            if layer_from.implementation == 2:
                if layer_to.implementation == 1:
                    copy_lc_weights_2_to_1(layer_from, layer_to)
                elif layer_to.implementation == 3:
                    copy_lc_weights_2_to_3(layer_from, layer_to)
                else:
                    raise NotImplementedError

            else:
                raise NotImplementedError

        elif isinstance(layer_from, keras.layers.Dense):
            weights_2, bias_2 = layer_from.weights
            weights_2 = keras.backend.get_value(weights_2)
            bias_2 = keras.backend.get_value(bias_2)
            layer_to.set_weights([weights_2, bias_2])

        else:
            continue


if __name__ == "__main__":
    tf.test.main()
